{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Non-Gaussian Likelihoods\n",
    "\n",
    "## Introduction\n",
    "\n",
    "This example shows the simplest way to use an RBF kernel in an `ApproximateGP` module for classification. This basic model is suitable when there is not much training data and no advanced techniques are required.\n",
    "\n",
    "In this example, we’re modeling a unit square wave with period 1/2 that is positive at x=0. We are going to classify each point as either 1 (where the wave is positive) or 0 (where it is negative).\n",
    "\n",
    "Variational inference approximates the intractable posterior with a distribution from a tractable family, chosen by minimizing the KL divergence between the approximation and the true posterior. A common simplifying choice, the mean-field assumption, is that the approximating distribution factorizes over the latent variables. For a good explanation of variational techniques, sections 4-6 of the following may be useful: https://www.cs.princeton.edu/courses/archive/fall11/cos597C/lectures/variational-inference-i.pdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import gpytorch\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up training data\n",
    "\n",
    "In the next cell, we set up the training data for this example. We'll be using 10 regularly spaced points on [0,1]. The training labels are the sign of a cosine wave with period 1/2 evaluated at those points, rescaled from {-1, +1} to {0, 1}, so the label is 1 wherever the wave is positive (as it is at x=0). No noise is added to the labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_x = torch.linspace(0, 1, 10)\n",
    "train_y = torch.sign(torch.cos(train_x * (4 * math.pi))).add(1).div(2)"
   ]
  },
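  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the labels, the sketch below mirrors the tensor expression above in plain Python. It is only an illustration: the helper name `square_wave_label` is hypothetical and is not part of the notebook or of GPyTorch.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def square_wave_label(x):\n",
    "    # cos(4 * pi * x) has period 1/2 and is positive at x = 0\n",
    "    value = math.cos(4 * math.pi * x)\n",
    "    # sign(value) is in {-1, 0, +1}; adding 1 and halving maps -1 -> 0 and +1 -> 1\n",
    "    sign = (value > 0) - (value < 0)\n",
    "    return (sign + 1) / 2\n",
    "\n",
    "\n",
    "# Same grid as torch.linspace(0, 1, 10)\n",
    "xs = [i / 9 for i in range(10)]\n",
    "labels = [square_wave_label(x) for x in xs]\n",
    "print(labels)\n",
    "```\n",
    "\n",
    "The labels alternate in runs of two between the positive and negative halves of the wave, which is the pattern the classifier has to recover."
   ]
  },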
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up the classification model\n",
    "\n",
    "The next cell demonstrates the simplest way to define a classification Gaussian process model in GPyTorch. If you have already done the [GP regression tutorial](../01_Exact_GPs/Simple_GP_Regression.ipynb), you have already seen how GPyTorch model construction differs from other GP packages. In particular, the GP model expects a user to write out a `forward` method in a way analogous to PyTorch models. This gives the user maximum flexibility.\n",
    "\n",
    "Since exact inference is intractable for GP classification, GPyTorch approximates the classification posterior using **variational inference.** We believe that variational inference is ideal for a number of reasons. Firstly, variational inference commonly relies on gradient descent techniques, which take full advantage of PyTorch's autograd. This reduces the amount of code needed to develop complex variational models. Additionally, variational inference can be performed with stochastic gradient descent, which scales well to large datasets.\n",
    "\n",
    "If you are unfamiliar with variational inference, we recommend the following resources:\n",
    "- [Variational Inference: A Review for Statisticians](https://arxiv.org/abs/1601.00670) by David M. Blei, Alp Kucukelbir, Jon D. McAuliffe.\n",
    "- [Scalable Variational Gaussian Process Classification](https://arxiv.org/abs/1411.2005) by James Hensman, Alex Matthews, Zoubin Ghahramani.\n",
    "  \n",
    "In this example, we're using an `UnwhitenedVariationalStrategy` because we are using the training data as inducing points. In general, you'll probably want to use the standard `VariationalStrategy` class for improved optimization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gpytorch.models import ApproximateGP\n",
    "from gpytorch.variational import CholeskyVariationalDistribution\n",
    "from gpytorch.variational import UnwhitenedVariationalStrategy\n",
    "\n",
    "\n",
    "class GPClassificationModel(ApproximateGP):\n",
    "    def __init__(self, train_x):\n",
    "        variational_distribution = CholeskyVariationalDistribution(train_x.size(0))\n",
    "        variational_strategy = UnwhitenedVariationalStrategy(\n",
    "            self, train_x, variational_distribution, learn_inducing_locations=False\n",
    "        )\n",
    "        super(GPClassificationModel, self).__init__(variational_strategy)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
    "\n",
    "    def forward(self, x):\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
    "        return latent_pred\n",
    "\n",
    "\n",
    "# Initialize model and likelihood\n",
    "model = GPClassificationModel(train_x)\n",
    "likelihood = gpytorch.likelihoods.BernoulliLikelihood()"
   ]
  },
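  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, the following is a minimal sketch of the same model built with the standard `VariationalStrategy` and learned inducing locations. The class name `WhitenedGPClassificationModel` and the choice of 5 inducing points are assumptions made for illustration, not part of the original example.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import gpytorch\n",
    "from gpytorch.models import ApproximateGP\n",
    "from gpytorch.variational import CholeskyVariationalDistribution\n",
    "from gpytorch.variational import VariationalStrategy\n",
    "\n",
    "\n",
    "class WhitenedGPClassificationModel(ApproximateGP):\n",
    "    def __init__(self, inducing_points):\n",
    "        # One variational mean/covariance entry per inducing point\n",
    "        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(0))\n",
    "        # Standard (whitened) strategy; inducing locations are optimized during training\n",
    "        variational_strategy = VariationalStrategy(\n",
    "            self, inducing_points, variational_distribution, learn_inducing_locations=True\n",
    "        )\n",
    "        super().__init__(variational_strategy)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
    "\n",
    "    def forward(self, x):\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
    "\n",
    "\n",
    "# Hypothetical choice: 5 inducing points spread over the input range\n",
    "inducing_points = torch.linspace(0, 1, 5)\n",
    "whitened_model = WhitenedGPClassificationModel(inducing_points)\n",
    "```\n",
    "\n",
    "With this variant the number of inducing points no longer has to match the number of training points, which is what makes the approach practical for larger datasets."
   ]
  },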
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model modes\n",
    "\n",
    "Like most PyTorch modules, the `ApproximateGP` has a `.train()` and `.eval()` mode.\n",
    "- `.train()` mode is for optimizing variational parameters and model hyperparameters.\n",
    "- `.eval()` mode is for computing predictions through the model posterior."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Learn the variational parameters (and other hyperparameters)\n",
    "\n",
    "In the next cell, we optimize the variational parameters of our Gaussian process.\n",
    "In addition, this optimization loop also performs Type-II MLE to train the hyperparameters of the Gaussian process."
   ]
  },
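  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the objective maximized in the next cell can be written as follows. This is the standard sparse variational GP bound, stated as a sketch in generic notation (the symbols $q$, $\\mathbf{u}$ and $f_i$ are not names used elsewhere in this notebook):\n",
    "\n",
    "$$\n",
    "\\mathcal{L}(q) = \\sum_{i=1}^{N} \\mathbb{E}_{q(f_i)}\\left[\\log p(y_i \\mid f_i)\\right] - \\mathrm{KL}\\left[q(\\mathbf{u}) \\,\\Vert\\, p(\\mathbf{u})\\right]\n",
    "$$\n",
    "\n",
    "Here $N$ is the number of training points (the `num_data` argument), $p(y_i \\mid f_i)$ is the Bernoulli likelihood, $q(\\mathbf{u})$ is the variational distribution over the function values at the inducing points, and $p(\\mathbf{u})$ is the GP prior at those points. The training loop minimizes the negative of this quantity. To our understanding, GPyTorch also rescales it by $1/N$, which does not change the optimum."
   ]
  },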
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 1/50 - Loss: 0.908\n",
      "Iter 2/50 - Loss: 4.296\n",
      "Iter 3/50 - Loss: 8.896\n",
      "Iter 4/50 - Loss: 3.563\n",
      "Iter 5/50 - Loss: 5.978\n",
      "Iter 6/50 - Loss: 6.632\n",
      "Iter 7/50 - Loss: 6.222\n",
      "Iter 8/50 - Loss: 4.975\n",
      "Iter 9/50 - Loss: 3.972\n",
      "Iter 10/50 - Loss: 3.593\n",
      "Iter 11/50 - Loss: 3.329\n",
      "Iter 12/50 - Loss: 2.798\n",
      "Iter 13/50 - Loss: 2.329\n",
      "Iter 14/50 - Loss: 2.140\n",
      "Iter 15/50 - Loss: 1.879\n",
      "Iter 16/50 - Loss: 1.661\n",
      "Iter 17/50 - Loss: 1.533\n",
      "Iter 18/50 - Loss: 1.510\n",
      "Iter 19/50 - Loss: 1.514\n",
      "Iter 20/50 - Loss: 1.504\n",
      "Iter 21/50 - Loss: 1.499\n",
      "Iter 22/50 - Loss: 1.500\n",
      "Iter 23/50 - Loss: 1.499\n",
      "Iter 24/50 - Loss: 1.492\n",
      "Iter 25/50 - Loss: 1.477\n",
      "Iter 26/50 - Loss: 1.456\n",
      "Iter 27/50 - Loss: 1.429\n",
      "Iter 28/50 - Loss: 1.397\n",
      "Iter 29/50 - Loss: 1.363\n",
      "Iter 30/50 - Loss: 1.327\n",
      "Iter 31/50 - Loss: 1.290\n",
      "Iter 32/50 - Loss: 1.255\n",
      "Iter 33/50 - Loss: 1.222\n",
      "Iter 34/50 - Loss: 1.194\n",
      "Iter 35/50 - Loss: 1.170\n",
      "Iter 36/50 - Loss: 1.150\n",
      "Iter 37/50 - Loss: 1.133\n",
      "Iter 38/50 - Loss: 1.117\n",
      "Iter 39/50 - Loss: 1.099\n",
      "Iter 40/50 - Loss: 1.079\n",
      "Iter 41/50 - Loss: 1.056\n",
      "Iter 42/50 - Loss: 1.033\n",
      "Iter 43/50 - Loss: 1.011\n",
      "Iter 44/50 - Loss: 0.991\n",
      "Iter 45/50 - Loss: 0.974\n",
      "Iter 46/50 - Loss: 0.959\n",
      "Iter 47/50 - Loss: 0.945\n",
      "Iter 48/50 - Loss: 0.932\n",
      "Iter 49/50 - Loss: 0.919\n",
      "Iter 50/50 - Loss: 0.906\n"
     ]
    }
   ],
   "source": [
    "# this is for running the notebook in our testing framework\n",
    "import os\n",
    "smoke_test = ('CI' in os.environ)\n",
    "training_iterations = 2 if smoke_test else 50\n",
    "\n",
    "\n",
    "# Find optimal model hyperparameters\n",
    "model.train()\n",
    "likelihood.train()\n",
    "\n",
    "# Use the Adam optimizer\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.1)\n",
    "\n",
    "# \"Loss\" for variational GPs - the negative ELBO, a lower bound on the marginal log likelihood\n",
    "# num_data refers to the number of training datapoints\n",
    "mll = gpytorch.mlls.VariationalELBO(likelihood, model, train_y.numel())\n",
    "\n",
    "for i in range(training_iterations):\n",
    "    # Zero backpropped gradients from previous iteration\n",
    "    optimizer.zero_grad()\n",
    "    # Get predictive output\n",
    "    output = model(train_x)\n",
    "    # Calc loss and backprop gradients\n",
    "    loss = -mll(output, train_y)\n",
    "    loss.backward()\n",
    "    print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iterations, loss.item()))\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make predictions with the model\n",
    "\n",
    "In the next cell, we make predictions with the model. To do this, we simply put the model and likelihood in eval mode, and call both modules on the test data.\n",
    "\n",
    "In `.eval()` mode, calling `model()` returns the GP's latent posterior predictions as a `MultivariateNormal` distribution. Since we are performing binary classification, we want to transform these outputs to classification probabilities using our likelihood.\n",
    "\n",
    "When we call `likelihood(model())`, we get a `torch.distributions.Bernoulli` distribution, which represents our posterior probability that the data points belong to the positive class.\n",
    "\n",
    "```python\n",
    "f_preds = model(test_x)\n",
    "y_preds = likelihood(model(test_x))\n",
    "\n",
    "f_mean = f_preds.mean\n",
    "f_samples = f_preds.sample(sample_shape=torch.Size((1000,)))\n",
    "```"
   ]
  },
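  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the transformation concrete: assuming the likelihood uses a probit link (our understanding of `BernoulliLikelihood`), a latent Gaussian with mean $\\mu$ and variance $\\sigma^2$ at a test point gives a positive-class probability of $\\Phi(\\mu / \\sqrt{1 + \\sigma^2})$, where $\\Phi$ is the standard normal CDF. The plain-Python sketch below evaluates that formula; the function name `probit_class_probability` is hypothetical.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def probit_class_probability(latent_mean, latent_var):\n",
    "    # Scale the latent mean by the total standard deviation sqrt(1 + var)\n",
    "    z = latent_mean / math.sqrt(1.0 + latent_var)\n",
    "    # Standard normal CDF expressed through the error function\n",
    "    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))\n",
    "\n",
    "\n",
    "# A latent mean of zero gives no preference for either class\n",
    "p_uncertain = probit_class_probability(0.0, 1.0)\n",
    "# A clearly positive latent mean with small variance favors the positive class\n",
    "p_confident = probit_class_probability(2.0, 0.5)\n",
    "print(p_uncertain, p_confident)\n",
    "```\n",
    "\n",
    "A larger latent variance pulls the probability toward 0.5, which is how uncertainty in the latent function shows up in the predicted class probabilities."
   ]
  },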
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQsAAADDCAYAAACVmTQ/AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAbPklEQVR4nO3dfXxU1bno8d/khSQEQ6AkvIgIYrGIwjn2OR+QgwjK4VCgKiiWT6tePli9ChSpkBC9YMAASYH2HDge0Sj39mhrCx5qK6U9wMXegry68ChQa5UohEJiBEyjvCQhmfvHzJBJmMnssCczezbP95/M7Nmz9pM1M8/ee+211/J4vV6UUiqSpHgHoJRKDJoslFKWaLJQSlmiyUIpZYkmC6WUJZoslFKWpNgtQET6A0uAd4HewCljzLMt1kkHVgLHga8DJcaYj+xuWykVO9E4sugK/NIYs8IY8wQwVUS+2WKdOUC5MaYY+BdgbRS2q5SKIdvJwhjzjjHmNy3KPNNitQnAbv/6B4EhIpJld9tKqdixfRoSTEQmAZuNMR+2eCkX+DLoeY1/WU24sgoKCrRrqVJxUlJS4mm5LGrJQkRGA6PxnXK0VAVcFfQ8y7+sVYsXL4643aqqKnJzcy1GGXtOjw+cH6PT4wN3xVhYWBhyeVSShYhMAG4DngB6isi1wF+AC8aYGmATcCuwQ0RuBt73L1dKJYhoXA35JrAOMMAfgEzg34FJwGmgBFgFrBSRBcD1wMN2t6uUii3bycIYsx/oFGGdc8BMu9tS7nfhwgUqKyupra0lcEd0Y2MjNTXOPhBNxBg9Hg9paWn06NGDlJTIqSCqDZxK2VVZWUlmZiZXX301Ho+vja2+vp7U1NQ4R9a6RIzR6/VSXV1NZWUlvXv3jvh+7cGpHKW2tpbs7OyLiUK1H4/HQ3Z2NrW1tZbW12ShHMXr9WqiiCGPx4PVAbA0WaiEVVFRwZgxY6isrLzsMvbv309eXh5FRUU8+eSTlJWVAbB8+XK+//3vRytUy86cOcPUqVN55ZVXmi3/wx/+wIABA8jLy6OwsJCnn36as2fPhi3nzTffpLq6OqqxabJQCau4uJhdu3axbNmyy3p/TU0NS5YsoaSkhIULF1JUVMTMmTNpaGjg/vvvj3K01mRmZjJ+/PhLlo8ePZprr72WmTNnsnjxYkaMGEF+fn7YctojWWgDp0o42dnZnD9//uLz0tJSSktLSU9Pb9MP5Pe//z3Dhg0jOTkZ8P1Q+/Xrx759++jZsyfHjh2jpKSE/fv3s3DhQvr378/s2bMZOHAgZWVlrFmzhl//+tfs3r2b+vp67rnnHtLT05k2bRrf+c532L17Nw899BDPPvssGzZsoLy8nNWrV7NhwwYWLFhAr169KC8vZ+nSpTQ2NjJjxgxuuOEGjh07xm233dZq7OPHj2f+/Pk0NjayfPly6uvrKS8vZ968eaSkpHDgwAGee+45vvWtb9HQ0MDvfvc7kpKSGDFiBJMnT76setdkoRLOn//8ZwoKCnjzzTc5d+4cGRkZ3H333ZSUlLSpnL/+9a987Wtfa7asW7duHD9+nJ49ewJQUFDA4cOHyc/Pp6ioiC+//JInnniC9957j4aGBpYsWcI777zDmTNnuOuuu3jrrbcYPnw4t99+O3PmzOGrr76iV69e9O3blxMnTrBo0SK2bt1KZmYm8+bN47XXXuPVV1+ltraWO++8k+nTp1NUVGQp/i5dunDy5EmGDRvGqFGjOHDgAC+//DIrVqxg8ODBzJo1i759+7J3715KSkpISkri7rvv1mShrhw9e/YkKyuL2tpa0tPTqa2tJSsrix49erSpnN69e3PkyJFmyz7//HOuueaai68D9OnTh7KyMgYNGsSwYcMYPXo048ePp0+fPtTU1LBy5UoaGhro3r37xXL69u1L586d6dy5M3fddRcbN27k8OHDLFiwgJUrV/Lx
xx+zYsUKqqurufrqqykrK+Oee+4BuLj9SKqrq+nWrRs1NTU888wzNDQ08MUXX1yyXmZmJs8++yxZWVmcPn26TXUUTNssVEKqqqrikUceYfv27TzyyCN89tlnbS5j3Lhx7N69m4aGBsDXuHjs2DFEBIDjx48DUF5eznXXXcenn37K5MmT2b59O9u2bcPj8dC9e3fy8vKYO3cu3/3ud0Nu5/777+eVV14hIyMDgBtuuIHBgweTl5fHnDlzuOWWW+jfvz9Hjx4FfEc8kWzevJlRo0aRlJREXl4eixYtYvr06RdfT05Oxuv18uGHH1JYWMjUqVOZO3cuHTt2bHM9BeiRhUpI69atu/h41apVl1VG586dWbhwIQUFBWRlZfG3v/2N1atXk5yczPr16zl37hxFRUW8//77FBYWUldXR3FxMf369WPw4MHk5uaSn59PQUEBycnJDB48mLKyMg4dOkRpaenFhtdevXrR2NjIhAkTAJgwYQI7d+6kuLiYiooK5s+fzze+8Q0ef/xxjh07xscff8yhQ4e49957yczMBGD79u2Ul5fzwgsvkJmZyblz5y6edo0bN47Zs2eTnZ3NoUOHKCsrY/jw4axYsYKBAwcyadIkFi9ezLBhwzh+/DhvvfUWd9xxR5vry+PUSYYKCgq8etdpbDgpxsOHD3P99dc3W5aIvSOdKFyMLeu8sLAw5C3qehqilLJEk4VSyhJNFkopSzRZKKUs0WShlLJEk4W6Ym3ZsoUBAwbw4osvNls+duxYpkyZYusGNTfSZKGuWGPHjmXs2LG8/PLLF2/T/uCDDzh9+jTf/va329wj1O20U5ZypPT0tKBnaWHXs+L8+fCDu3Tv3p2BAweybds2xowZw8aNG5k4cSIAO3fu5PXXX6dTp04MGTKEKVOmMG/ePHJycvjoo48oLi7myJEjTJs2jQceeIADBw4wdOhQfvjDH9qK16miNbp3D3xTGA4xxvxDiNdHAf8KBG4J3GSMWRGNbStl14wZM1i+fDkiQlZWFnV1dQDk5+ezefNmOnXqxMiRI7nvvvuYOHEio0aNYuPGjaxfv55Zs2YxfPhwhg4dylNPPcXo0aM1WUQwAvgN8HetrDPHGPP/orQ95XLBRwPt3Tty2LBhnDx5kpKSEvLy8nj++ecB3z0aa9asAXw3ldXU1FBWVsbbb7/NiRMnLt6ZCr6bzTwej+N7cdoRlWRhjPlP/9FDax4U3x06WcBLxphj0di2UtHw2GOPsWPHjma3rF933XXMnj2btLQ0fvvb33LhwgVKS0vZu3cv27ZtY9euXRfXvRKGAoxVm8UHQJEx5oiIDAK2isiNxpjG1t5UVRVx0rKojwYUbU6PD5wVY2NjI/X19c2WBe4Kjba3336bHTt2sG3bNqZMmcKUKVP46KOP2LFjBwcPHmTx4sXk5+eTnZ1NTk4OY8eOpW/fvsydO5czZ87wySefsGfPHg4ePMgbb7zBLbfcQnl5OVu3bmXUqFHtErMd4eqxsbHR0m8tJsnCGFMV9PhPIpINXAMcbe19Vm9ucspNUOE4PT5wTow1NTUhD+Xb4/B+9OjRjB49utmyAQMGsGXLlovPR44c2ez14LtdA/bu3XvxVOkvf/lL1OOMplD1mJSUZOnzb7dLpyKSKSI5/scFItLV/7gr0AFo+wAESqm4idbVkNuBB/HNc7oA+DEwDbgZeAw4AqwSkQ+AG4EHjTHnQ5emlHKiaDVw/hH4Y4vF/x70+i+BX0ZjW8rdAvNYXAkNhk7QlrrWHpzKUdLS0qiurrY88Y26fIHpC9PSrHV60x6cylF69OhBZWUlp0+fbjYxclKSs/driRhj8MTIVmiyUI6SkpJyySS9Thr2L5wrIUZnp0KllGNoslBKWaLJQilliSYLpZQlmiyUUpZoslBKWaLJQilliSYLpZQlmiyUUpZoslBKWaLJQilliSYLpZQlmiyUUpZoslBKWaLJQilliSYLpZQlCZ0sKioq2mW264qKCsaM
GaOzaPu1V31oPTfXnvURjd9KVJKFiPQQkZdF5J0wryeJSImILBCRl0RkWDS2W1xczL59+1i2bFk0imtW7q5du6JebqJqr/rQem6uPesjGr8VTzQGRhWR+4BaoNAYIyFenwqMNMbM8M8bsgcYaIwJO9VUQUGBd/HixSFfy87O5vz54cClE74ETz/XVqdOnQr7Wrhyr7/ey+bN9aSnh36fE4ZbO3MGxo1L5ZNPQo/iHG78yMupDyvaWm64+DIyoLS0njvuiP/gvq19zkuXJrNmTTLhfmrtVc+Xlr0OmAVAenp62JnoCgsLKSkpueTLEqu5TicAW/zrnhaR88Ag4EBr5YabUm3Hjh384Aeb2LOn2yWvtVLvFlxaXqRyT53ysGPHFwwZUh/ydSdMDbh/fwfeead7K2skh1ne9vqwpq3lhosP1q+v5aab4l/HrX3Or77ag5MnWxtuv73quWXZnUhPT2fcuHEsWLDA0pSFwWI1YG8u8GXQ8xr/stbfFCZT5+bmcuONa9m7N5fU1FTq6+t58MEHWLas2HagTz31FD/72c/o0KEDdXV1rZY7eXIq+/YlkZHRldzc8Hu3eB9ZpKX5vqi33trI669fmtROnjxJt26hv7BtqY+2aEu5oeL7+c+TmT8/BY+nI7m5HWzHEw3hPue6Ot/PbN++Onr1Cv09aa96Di47NbWRuro6cnNzuemmm9pcTqySRRVwVdDzLP+yy3bqVAWPPjqZSZMm8cYbb1BZ+Qlhvu9tUlPzCY8+OpmHH36YtWvXtlruVf7/6Nw5+9ttT+fP+5JF587ekP9LY2Nj2P+xLfXRFm0pN1R8OTm+H53T6x6aYrz6ai/hziraq56Dy276rVxeI2e7JQsRyQQ6GmM+BzYBI4FX/W0W6cCf7JQfmKC2qqoqqjNWB098u2rVqlbXzcgIfGE9QPzPm8M5e9b3NyOj7e9tS33EstzA/5IIycJK/bdXPQeXbfe3Equ5TtcDfy8ihUAf4KHWGjcTRaJ8YQPxXU6ycKqOHX1/fYnauRobobbWF2O4RvBEEau5ThuB+dHYlpMEvrCBPYdTBZJFIF43CBzVJUrdZ2R4cfiEZRElePjxlThHFr49W+AH5gaJU/e+v244qtNkYUPHjsFtFs5lp83CqZpOQ+IbRyRuqntNFjYEzkET51A4vnFEU3p64DTE2Yk6cCXKDUd1mixsCOzdzp+PbxyRuLHNIlHqPrAjcUPda7KwIbCndvreLRCfG/ZuAU11H984ItHTEAUE97OIcyARuPE0JFHaLNxU95osbEiUvZsbT0NSUyEpyUt9vYcLF+IdTXiBxu9AY3gi02RhQ6KcN7tp7xbg8STG0YWb6l6ThQ1NHYO0zSIeEuHITtssFJAYX1ZoOvJxwxc2WCLUv5vqXpOFDYlyGuKmy3fBAu0Agb4MThQ4qtM2iytcIuzZwF3nzcESof4DdZ/oN5GBJgtbmt+i7lxuvDcEEuP+EDcd1WmysCERvqzgri9ssESofzddttZkYUMiXLoD956GJMKNfIHYAveyJDJNFjYEnzNHYZD0dtHQ4J7BV1pKhBv53HRUp8nChtRUSEnx0tDgoT704N5x13TpzovHuTvgy5IIgw/ppVN1kdPPm920Z2spUPdOvnTddOk0zoFEgSYLm5y+d3PTpbuWAm0WTu5B21T/Dj1PbYNoDdg7BpiMb3h/rzFmcYvXp+EbuDewD1hrjHk1GtuON6cfWbjpRqaWnF734K6rIbaThYh0BF4ABhljakVkg4jcaYzZ1mLVqcaYI3a35zS+vgsex04H4KYva0uaLGIrGkcWtwJHjTG1/uc78U1X2DJZzBKRSqAj8Jwx5nSkgq1Mrxbv6QFTU7sDHThx4gtycuoueT3e8R0/3gHoTkpKfdj6jHeMkYSLr6GhE9CF06fPUVUV3/8hXIxfftkTSOLs2VNUVcX3Xnq7n3M0koWVqQn/CGwyxnwuIuOB14E7IxZscdq/eE4PeNVVvipMS+sS
dgrDeMYXmLowKyu11TjiPcViJKHiy8nxNbl5vc6YwjBUjHV1vnlae/fuihOq2M7nHI1kEXFqQmPMp0FP3wLeFJFkN0w01HyyG+edhrhpwNiWnN64DO7qEBeNqyG7gWtFJM3//B+BTSLSVUSyAESkWEQCienrwKduSBTg/KH19NJpfLmp/m0nC2PMWeBxYLWILAEO+Bs3C4AZ/tUqgTUi8jTwNL6pDl3B6Xc+6qXT+LlwAerrPXg8XjrE/yzJtmhNX7gV2NpiWX7Q4+jO9OogTm+Rd1NrfEuBBJgIde+G3rPaKcsmp9/M5OZ+Fk6/kc9tiVqThU1OP7IInB658TRE6z62NFnYlChtFm7ZuwVz+oDJgStRbjmq02RhU6IcCrvh0l1LTq97N43sDZosbHP60HpuHVIPnH8a4rZErcnCJqefhrht7xbM6YMPua3uNVnY5PSOQW5uswgMPtTY6MzBh9x2JUqThU1O73Lc9IWNcyDtxMn1r6chqpnAoCZObbNounznjr1bS05ut9DTENWMk/ds0HR65NYjCye3Gbmt7jVZ2OTkPRu4b+/WkpOvRrltQmpNFjY5/Vq/my+dgrPrX9ssVDNO70Xoti9sS04+DXFb3WuysEkvncZXU/07L1m7re41Wdjk9AZONw2+EkrTmBZxDiQEbbNQzQQ3cDqtF2Fg8JWkJC+pqfGOpn04eUwLtzUua7KwKTkZOnTw4vV6qK2NvH4suW3wlVCcfGSnl07VJZzayOa2PVsoeuk0djRZREFbLt9VVFQwZswYKisroxpDqHLd1hofSri6b696bkvZbqv/qCQLERkjIs+LyCIRKQzxerqIPCciT4nI/xaRAdHYrlM0nTdH3rsVFxeza9culi1bFtUYQpXr9j4W0FT3LY/q2que21K225JFrKYvnAOUG2OWi8jNwFrgNrvbdgpfi7yHpUuT6dat+Q/z7NlsOnZM5vnn19DQcAH4BrCS0lIoLX2e5OQUZsx4/LK33Vq59903wx/fZRfveIH/bcuWJM6cab0+7NQztF72tGnT6Ngxudn6R4+66ya+WE1fOAHfFAAYYw6KyBARyTLG1LRWcCJMXwjQpUsOkM4vfpEc4tXA/Es/CPnehgb4t3+zs/Xw5a5b53ucnV1LVdXJsCU4oQ5b01p8GRkdga/x7rtJvPtuEu1Xz7Ra9tq1od/h8Xjxej+nqqrR7sZtS5TpC8Ot02qySITpCwFeegk2bbpAY4jvw1dffUmnTr6E8atf/Yq9e/eQnJxCQ8MFhg4dxuTJk21vv7Vyk5Jg4sSkiHUU7zqMJFx806dDZmY91dVNp4DtVc+tlR38OQcbOLCRQYO6RWXb0eD46QstrpOw+vWDWbNCT7BWVfUVubm+49CdO1/j0Ud78PDDD7N27VoqK19j9uy7bW+/vcpNBGlp8NBDzbN0e9ZHuLKDP2e3ikayuDh9of9U5B+B50WkK3DBf6qxCd/pyg5/m8X7kU5B3Ghd4LwAWLUqevMutVe5iao96+NKrutYTV+4Cl9CWQDMBR62u12lVGzFavrCc8DMaGxLKRUf2ilLKWWJJgullCWaLJRSlmiyUEpZoslCKWWJJgullCWaLJRSlmiyUEpZoslCKWWJJgullCWaLJRSlmiyUEpZoslCKWWJJgullCWaLJRSlmiyUEpZoslCKWWJJgullCWaLJRSltgag9M/gncJ8AnwdeBpY8xnIdY7AhzxPz1ujPmene0qpWLP7oC9y4D/a4xZLyLfBlYCD4ZY76fGmEU2t6WUiiO7yWICsNT/eCfwH2HWGyki+fgmGvq9MWaXze0qpWIsYrIQkc1A9xAvPUPzaQlrgC4ikmKMudBi3QJjzD7/JMrvishEY8zhSNtOlLlOW+P0+MD5MTo9PrgyYoyYLIwx/xzuNREJTEtYjW9Kwi9CJAqMMfv8f8+KyHv4Zi2LmCwSZa7TSJweHzg/RqfHB+6P0e7VkMC0hOBLAJsARCRJRPr4H98pIuOC
3nM9UGZzu0qpGLPbZvE08CMRGQD0B+b5lw8GXgVuxjcB8iIRuQXoBWwwxrxtc7tKqRizlSyMMaeBR0Isfw9fosAYcxC41852lFLxp52ylFKWaLJQSlmiyUIpZYkmC6WUJZoslFKWaLJQSlmiyUIpZYkmC6WUJZoslFKWaLJQSlmiyUIpZYkmC6WUJZoslFKWaLJQSlmiyUIpZYkmC6WUJZoslFKWaLJQSlmiyUIpZYnd6QuT8I3BWQTcYYw5FGa9McBkfIP3eo0xi+1sVykVe3aPLIYAe4Gz4VbwTyz0AvBD/xSGg0XkTpvbVUrFmK1kYYz5b/9I3q25FThqjKn1P9+Jb9pDpVQCsTV9oTHmTQvbCJ7iEHzTHFqaFqmwsNDKakqpGLA1faFFgSkOA7L8y1pVUlLisbldpVQUtdvVEBHp53+4G7hWRNL8zy9Oc6iUShwer9d72W8WkS7ATGAuvukKXzPG7BGRHOA9oL8x5ryI/BNwH/A5UK9XQ5RKPLaShVLqyqGdspRSlmiyUEpZYqsHZ6xE6gEqIunASuA48HWgxBjzkcNinA/0ACqBb+K79Pyhk2IMWu97wM+Aq4wxXzklPhHxAD/wP+0LZBtjpscqPosx9sP3XXwH+Dt87XhWuhhEK74ewBJgiDHmH0K8ngQsA74CrgXWGmP2WCnb8UcWFnuAzgHKjTHFwL8Aax0YYyfgSWPMj4ANwAoHxoiIDARujGVs/u1aie8BoNoYs9oY8yTwrw6MMR942xhTAvwI+HEsYwRGAL8BwnU9uB/IMsYsAeYDr4hIspWCHZ8ssNYDdAK+S7QYYw4CQ0QkK3YhRo7RGLPQGBNoTU7Cl9ljKWKM/h9DPhCPq1VWPufvAV1FZLaIBPaOsWQlxs+AHP/jHGB/jGIDwBjznzTvBNlS8G/lNHAeGGSl7ERIFlZ6gF52L9Eosbx9EekA/A9gQQziCmYlxqVAkTGmLmZRNbES37X49oqrgZ8C/2V1rxglVmL8CTBURH4CPAP8nxjFZtVl/1YSIVlY6QF6Wb1Eo8jS9v2JYg3wv4wxZTGKLaDVGEXkGqALcL+IFPgXPyki4oT4/Grw3biIv00qC7gmJtH5WInxp8DL/tOkScA6Eekam/AsuezfSiIki5A9QEWka9CpxiZ8h4iIyM3A+8aYGifFKCIZwIvAT4wx+0Xk3hjGFzFGY8wxY8w0Y0yJ/3wbf6zGCfH5l20DrgPwL0vG12AcK1ZivAao8D/+Amgkzr8zEcn0d5SE5r+VrkA68Ccr5SREp6xQPUBFZDlw2hhT4v8hrsT3IV0PLIvD1ZBIMf4KuAk44X9LZqjW6njG6F8nB/if+MYoKQJeNMYcd0J8ItIZWA4cBfoDG4wxv4tFbG2IcQS+Bvd3gX7AfmPMCzGM73bgIWAcvqPYHwPTgZuNMY/5r4YU4xtWog/wktWrIQmRLJRS8ZcIpyFKKQfQZKGUskSThVLKEk0WSilLNFkopSzRZKGUskSThVLKkv8P9WWPNod/4y4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 288x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Go into eval mode\n",
    "model.eval()\n",
    "likelihood.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Test x are regularly spaced by 0.01 on [0, 1] inclusive\n",
    "    test_x = torch.linspace(0, 1, 101)\n",
    "    # Get classification predictions\n",
    "    observed_pred = likelihood(model(test_x))\n",
    "\n",
    "    # Initialize fig and axes for plot\n",
    "    f, ax = plt.subplots(1, 1, figsize=(4, 3))\n",
    "    ax.plot(train_x.numpy(), train_y.numpy(), 'k*')\n",
    "    # Get the predicted labels (probabilities of belonging to the positive class)\n",
    "    # Transform these probabilities to be 0/1 labels\n",
    "    pred_labels = observed_pred.mean.ge(0.5).float()\n",
    "    ax.plot(test_x.numpy(), pred_labels.numpy(), 'b')\n",
    "    ax.set_ylim([-1, 2])\n",
    "    ax.legend(['Observed Data', 'Mean'])"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
